"""
number of resnet/densenet/pool
                    ----add resnet/densenet/pool
                    ----remove resnet/densenet/pool
properties of resnet/dense/pool
                    ----in_channel/out_channel of resnet
                    ----amount in one resnet
                    ----in_channel/out_channel of densenet
                    ----amount in one densenet
                    ----k of densenet
                    ----pooling type

firstly, three basic operations:add, remove, alter
secondly, the particular operation is chosen based on a probability
"""
import random
import numpy as np
import copy
from utils import StatusUpdateTool, Utils


class CrossoverAndMutation(object):
    def __init__(self, prob_crossover, prob_mutation, _log, individuals, mutation_pro, _params=None):
        """Store the settings needed for one crossover-then-mutation round.

        Args:
            prob_crossover: probability handed to the crossover stage.
            prob_mutation: probability handed to the mutation stage.
            _log: logger object shared with both stages.
            individuals: current population the parents are drawn from.
            mutation_pro: weights used to pick the mutation type.
            _params: optional extra parameters; ``process`` reads ``'gen_no'``
                from it.
        """
        # Operator probabilities and the mutation-type weights.
        self.prob_crossover, self.prob_mutation = prob_crossover, prob_mutation
        self.Mutation_pro = mutation_pro

        # Population, logger and free-form parameters.
        self.individuals = individuals
        self.log = _log
        self.params = _params

        # Filled in by process().
        self.offspring = []
        self.parents = []

    def process(self):
        """Run crossover, then mutation, and return the offspring list.

        The offspring are written out once after crossover and once after
        mutation through ``Utils``; both snapshots are keyed by
        ``self.params['gen_no']``.  After mutation every offspring gets the id
        ``'indi<generation><index>'`` (two digits each).

        Returns:
            The offspring list produced by the crossover stage; the mutation
            stage works on these same objects.
        """
        # Crossover stage: produces the offspring and the selected parents.
        crossover_stage = Crossover(self.individuals, self.prob_crossover, self.log)
        children, chosen_parents = crossover_stage.do_crossover()
        self.offspring, self.parents = children, chosen_parents

        # Persist the population as it looks right after crossover.
        generation = self.params['gen_no']
        Utils.save_population_after_crossover(self.individuals_to_string(), generation)

        # Mutation stage: operates on the offspring list.
        mutation_stage = Mutation(self.offspring, self.prob_mutation, self.log, self.Mutation_pro)
        mutation_stage.do_mutation()

        # Give every offspring its id for this generation.
        for position, child in enumerate(self.offspring):
            child.id = 'indi%02d%02d' % (generation, position)

        # Persist the population as it looks after mutation.
        Utils.save_population_after_mutation(self.individuals_to_string(), generation)
        return children

    def individuals_to_string(self):
        """Render ``self.offspring`` as text, one individual per block.

        Each individual's ``str()`` is followed by a line of 100 dashes; all
        lines are joined with newlines (no trailing newline).
        """
        divider = '-' * 100
        return '\n'.join(
            line
            for member in self.offspring
            for line in (str(member), divider)
        )


class Crossover(object):
    def __init__(self, individuals, prob_, _log):
        """Keep the population, the crossover probability and the logger.

        Args:
            individuals: population the parents are selected from.
            prob_: probability that a selected pair is actually crossed over.
            _log: logger object used for progress messages.
        """
        self.individuals, self.prob, self.log = individuals, prob_, _log
        # The second entry of the pool limit is used by do_crossover as the
        # upper bound on pooling units allowed in a child.
        pool_limits = StatusUpdateTool.get_pool_limit()
        self.pool_limit = pool_limits[1]

    def _choose_one_parent(self):
        """Run one binary tournament and return the index of the winner.

        Two distinct indices are drawn uniformly at random.  The first index
        wins when its individual has a strictly larger ``acc``; otherwise the
        second index is returned (so ties go to the second draw).

        Returns:
            int: index into ``self.individuals``.

        Raises:
            ValueError: if fewer than two individuals are available.  Without
                this guard the search for a second, different index would
                never terminate.
        """
        count_ = len(self.individuals)
        if count_ < 2:
            raise ValueError(
                'Binary tournament selection needs at least two individuals, got %d' % count_)

        idx1 = int(np.floor(np.random.random() * count_))
        idx2 = int(np.floor(np.random.random() * count_))
        # Redraw until the two contestants differ.
        while idx2 == idx1:
            idx2 = int(np.floor(np.random.random() * count_))

        if self.individuals[idx1].acc > self.individuals[idx2].acc:
            return idx1
        return idx2

    """
    binary tournament selection
    """

    def _choose_two_diff_parents(self, max_tries=100):
        """Select two different parents by repeated binary tournaments.

        The first parent is one tournament winner.  Further tournaments are
        run until a different winner appears.  If the same winner keeps coming
        back (this always happens with a population of two whose ``acc``
        values differ, because the fitter one wins every tournament), the
        second parent is drawn uniformly from the remaining individuals after
        ``max_tries`` tournaments instead of looping forever.

        Args:
            max_tries: maximum number of tournaments run for the second
                parent before falling back to a uniform draw.

        Returns:
            tuple: ``(idx1, idx2)`` with ``idx1 != idx2``, both valid indices
            into ``self.individuals``.

        Raises:
            ValueError: if fewer than two individuals are available.
        """
        population_size = len(self.individuals)
        if population_size < 2:
            raise ValueError(
                'Two different parents need at least two individuals, got %d' % population_size)

        idx1 = self._choose_one_parent()
        idx2 = self._choose_one_parent()
        tries = 1
        while idx2 == idx1 and tries < max_tries:
            idx2 = self._choose_one_parent()
            tries += 1

        if idx2 == idx1:
            # The tournament cannot produce another winner; pick any other
            # individual so that crossover can still proceed.
            others = [i for i in range(population_size) if i != idx1]
            idx2 = others[int(np.floor(np.random.random() * len(others)))]

        return idx1, idx2

    """
    calculate the number of pooling units after the crossover is done
    """

    def _calculate_pool_numbers(self, parent1, parent2):
        """Draw one cut point per parent and count the pooling units of each child.

        A cut point is drawn uniformly for each parent.  Child 1 would be the
        head of ``parent1`` (units before its cut) followed by the tail of
        ``parent2`` (units from its cut on); child 2 is the mirror image.
        Pooling units are the ones whose ``type`` equals 2.

        Returns:
            tuple: ``(pos1, pos2, new_pool_number1, new_pool_number2)``.
        """
        units_a, units_b = parent1.units, parent2.units
        len_a, len_b = len(units_a), len(units_b)

        # Cut points; the draw order (parent1 first) is kept.
        cut_a = int(np.floor(np.random.random() * len_a))
        cut_b = int(np.floor(np.random.random() * len_b))
        assert cut_a < len_a
        assert cut_b < len_b

        def count_pools(units):
            return sum(1 for unit in units if unit.type == 2)

        head_a, tail_a = count_pools(units_a[:cut_a]), count_pools(units_a[cut_a:])
        head_b, tail_b = count_pools(units_b[:cut_b]), count_pools(units_b[cut_b:])

        return cut_a, cut_b, head_a + tail_b, head_b + tail_a

    def do_crossover(self):
        """Create the offspring population by one-point crossover.

        ``len(self.individuals) // 2`` parent pairs are selected.  Each pair is
        deep-copied; with probability ``self.prob`` the two copies exchange
        their tails at independently drawn cut points, otherwise the copies are
        passed on unchanged.  A split is only accepted when neither child
        exceeds ``self.pool_limit`` pooling units and neither child starts with
        a pooling unit.  Renumbering, channel re-adjustment and the offspring
        bookkeeping happen exactly once per pair, after an acceptable split has
        been found, so a rejected split never alters the parents and never
        adds extra entries to the offspring list.

        Returns:
            tuple: ``(new_offspring_list, parent_list)``.

        NOTE(review): the deep-copied parents are modified in place to become
        the offspring, so ``parent_list`` holds the same objects as
        ``new_offspring_list``.  This is kept as-is; confirm that callers do
        not expect the pre-crossover structure in ``parent_list``.
        """
        _stat_param = {'offspring_new': 0, 'offspring_from_parent': 0}
        new_offspring_list = []
        parent_list = []
        for _ in range(len(self.individuals) // 2):
            ind1, ind2 = self._choose_two_diff_parents()
            # Work on copies so the current population stays untouched.
            parent1 = copy.deepcopy(self.individuals[ind1])
            parent2 = copy.deepcopy(self.individuals[ind2])
            parent_list.append(parent1)
            parent_list.append(parent2)

            if random.random() < self.prob:
                _stat_param['offspring_new'] += 2
                pos1, pos2, unit_list1, unit_list2 = self._split_until_valid(parent1, parent2)

                # Renumber every unit according to its place in the child.
                for i, unit in enumerate(unit_list1):
                    unit.number = i
                for i, unit in enumerate(unit_list2):
                    unit.number = i

                # Make the channels consistent across each junction.
                self._readjust_channels(unit_list1, pos1, parent1.id)
                self._readjust_channels(unit_list2, pos2, parent2.id)

                parent1.units = unit_list1
                parent2.units = unit_list2
                # The structure changed, so the stored accuracy is stale.
                parent1.reset_acc()
                parent2.reset_acc()
            else:
                _stat_param['offspring_from_parent'] += 2

            new_offspring_list.append(parent1)
            new_offspring_list.append(parent2)

            self.log.info('CROSSOVER-%d offspring are generated, new:%d, others:%d' % (
                len(new_offspring_list), _stat_param['offspring_new'], _stat_param['offspring_from_parent']))
        return new_offspring_list, parent_list

    def _split_until_valid(self, parent1, parent2):
        """Draw cut points until both children respect the pool limit and start with a non-pooling unit."""
        while True:
            pos1, pos2, pool_len1, pool_len2 = self._calculate_pool_numbers(parent1, parent2)
            try_count = 1
            while pool_len1 > self.pool_limit or pool_len2 > self.pool_limit:
                pos1, pos2, pool_len1, pool_len2 = self._calculate_pool_numbers(parent1, parent2)
                try_count += 1
                self.log.warn('The %d-th try to find the position for do crossover' % (try_count))
            self.log.info('Position %d for %s, positions %d for %s' % (pos1, parent1.id, pos2, parent2.id))

            # Child 1: head of parent 1 + tail of parent 2; child 2 mirrors it.
            unit_list1 = list(parent1.units[:pos1]) + list(parent2.units[pos2:])
            unit_list2 = list(parent2.units[:pos2]) + list(parent1.units[pos1:])

            first_begin_is_pool = unit_list1[0].type == 2
            second_begin_is_pool = unit_list2[0].type == 2
            if first_begin_is_pool:
                self.log.warn('Crossovered individual#1 starts with a pooling layer, redo...')
            if second_begin_is_pool:
                self.log.warn('Crossovered individual#2 starts with a pooling layer, redo...')
            if not (first_begin_is_pool or second_begin_is_pool):
                return pos1, pos2, unit_list1, unit_list2

    def _readjust_channels(self, unit_list, pos, owner_id):
        """Propagate the output channel found before ``pos`` into the units from ``pos`` on."""
        keep_out_channel = 0
        if pos == 0:
            # The junction is at the very front: the input is the image itself.
            keep_out_channel = StatusUpdateTool.get_input_channel()
            source_index = -1
        else:
            source_index = 0
            # Closest RESNET/DENSENET unit before the junction.
            for source_index in range(pos - 1, -1, -1):
                if unit_list[source_index].type in (1, 3):
                    keep_out_channel = unit_list[source_index].out_channel
                    break

        for j in range(pos, len(unit_list)):
            unit = unit_list[j]
            if unit.type not in (1, 3):
                continue  # pooling units carry no channel information
            self.log.info(
                'Change the input channel of unit at %d to %d that is the output channel of unit at %d in %s' % (
                    j, keep_out_channel, source_index, owner_id))
            unit.in_channel = keep_out_channel
            if unit.type == 1:
                # A RESNET unit keeps its own output channel, so nothing
                # further down is affected.
                break
            # A DENSENET unit's output depends on its input; recompute it and
            # carry the new value on to the next unit in every case.
            estimated_out_channel = unit.k * unit.amount + unit.in_channel
            self.log.info(
                'Due to the above change, unit at %d changes its output channel from %d to %d' % (
                    j, unit.out_channel, estimated_out_channel))
            unit.out_channel = estimated_out_channel
            keep_out_channel = estimated_out_channel
            source_index = j


class Mutation(object):

    def __init__(self, individuals, prob_, _log, mutation_pro):
        """Keep the individuals to mutate together with the mutation settings.

        Args:
            individuals: individuals that ``do_mutation`` iterates over.
            prob_: probability that a single individual is mutated.
            _log: logger object used for progress messages.
            mutation_pro: weights used to choose among the mutation types.
        """
        self.individuals, self.prob = individuals, prob_
        self.Mutation_pro = mutation_pro
        self.log = _log


    def do_mutation(self):
        """Mutate every individual with probability ``self.prob``.

        For each individual selected for mutation, a mutation type is drawn
        from ``self.Mutation_pro``: 0 adds a unit, 1 removes a unit, 2 alters
        a property of a unit.  The chosen type is stored on the individual as
        ``Mutation_type``; individuals that are not selected get
        ``Mutation_type = 3``.  A summary line is logged at the end.

        Returns:
            The last individual that was processed, or ``None`` when there are
            no individuals.

        Raises:
            TypeError: if the drawn mutation type is not 0, 1 or 2.
        """
        _stat_param = {'offspring_new': 0, 'offspring_from_parent': 0, 'ADD': 0, 'REMOVE': 0, 'ALTER': 0,
                       'RESNET_OUT_CHANNEL': 0, 'RESNET_AMOUNT': 0, 'DENSENET_AMOUNT': 0, 'POOLING_TYPE': 0}
        mutation_list = self.Mutation_pro
        indi = None  # stays None when the population is empty
        for indi in self.individuals:
            if random.random() < self.prob:
                _stat_param['offspring_new'] += 1
                mutation_type = self.select_mutation_type(mutation_list)
                if mutation_type == 0:
                    _stat_param['ADD'] += 1
                    indi.Mutation_type = 0
                    self.do_add_unit_mutation(indi)
                elif mutation_type == 1:
                    _stat_param['REMOVE'] += 1
                    indi.Mutation_type = 1
                    self.do_remove_unit_mutation(indi)
                elif mutation_type == 2:
                    indi.Mutation_type = 2
                    mutation_p_type, mutation_p_count = self.do_alter_mutation(indi)
                    _stat_param[mutation_p_type] += mutation_p_count
                    _stat_param['ALTER'] += mutation_p_count
                    if mutation_p_count == 0:
                        # The alteration drew the value the unit already had,
                        # so the individual is effectively unchanged.
                        _stat_param['offspring_new'] -= 1
                        _stat_param['offspring_from_parent'] += 1
                else:
                    raise TypeError('Error mutation type :%d, validate range:0-2' % (mutation_type))
            else:
                indi.Mutation_type = 3  # marks "not mutated"
                _stat_param['offspring_from_parent'] += 1
        self.log.info(
            'MUTATION-mutated individuals:%d[ADD:%2d,REMOVE:%2d,ALTER:%2d,RESNET_OUT_CHANNEL:%2d, RESNET_AMOUNT:%2d, DENSENET_AMOUNT:%2d, POOLING_TYPE:%2d, no_changes:%d' % (
                _stat_param['offspring_new'],
                _stat_param['ADD'], _stat_param['REMOVE'], _stat_param['ALTER'], _stat_param['RESNET_OUT_CHANNEL'],
                _stat_param['RESNET_AMOUNT'], _stat_param['DENSENET_AMOUNT'], _stat_param['POOLING_TYPE'],
                _stat_param['offspring_from_parent']))
        return indi

    # def do_not_mutation(self,indi):
    #     pass

    def do_add_unit_mutation(self, indi):
        """Insert one new unit directly after a randomly chosen unit of ``indi``.

        The new unit is a RESNET (type 1), POOLING (type 2) or DENSENET
        (type 3) unit, each drawn with a probability of about 1/3.  If a
        POOLING unit is drawn but the individual already holds the maximum
        number of pooling units, a RESNET or DENSENET unit is added instead
        (probability 1/2 each).

        A RESNET/DENSENET unit takes its input channel from the closest
        RESNET/DENSENET unit at or before the insertion position.  The units
        after the insertion point are then updated so that each one's input
        channel matches the output channel of the unit feeding it; a DENSENET
        unit's output channel is recomputed from its new input and carried on
        to the following unit.

        Raises:
            IndexError: if ``indi`` has no units.
        """
        if not indi.units:
            raise IndexError('Cannot add a unit to an individual without units')
        self.log.info('Do the ADD mutation for indi:%s' % (indi.id))

        # The new unit goes right behind this position.
        mutation_position = int(np.floor(np.random.random() * len(indi.units)))
        self.log.info('Mutation position occurs at %d' % (mutation_position))

        # Choose the kind of unit to add.
        u_ = random.random()
        if u_ < 0.333:
            type_ = 1
        elif u_ < 0.666:
            type_ = 2
        else:
            type_ = 3
        type_string_list = ['RESNET', 'POOLING', 'DENSENET']
        self.log.info('A %s unit would be added due to the probability of %.2f' % (type_string_list[type_ - 1], u_))

        if type_ == 2:
            # Fall back to a conv unit when no further pooling unit is allowed.
            pool_limit = StatusUpdateTool.get_pool_limit()[1]
            num_exist_pool_units = sum(1 for unit in indi.units if unit.type == 2)
            if num_exist_pool_units > pool_limit - 1:
                u_ = random.random()
                type_ = 1 if u_ < 0.5 else 3
                self.log.info(
                    'The added unit is changed to %s because the existing number of POOLING exceeds %d, limit size:%d' % (
                        'RESNET' if type_ == 1 else 'DENSENET', num_exist_pool_units, pool_limit))

        if type_ == 2:
            add_unit = indi.init_a_pool(mutation_position + 1, _max_or_avg=None)
        else:
            # Input channel = output of the closest conv unit in front.
            _in_channel = None
            for i in range(mutation_position, -1, -1):
                if indi.units[i].type in (1, 3):
                    _in_channel = indi.units[i].out_channel
                    break
            if _in_channel is None:
                # No conv unit in front: the new unit is fed by the network
                # input, as do_crossover assumes for a junction at the head.
                _in_channel = StatusUpdateTool.get_input_channel()

            if type_ == 1:
                add_unit = indi.init_a_resnet(mutation_position + 1, _amount=None, _in_channel=_in_channel,
                                              _out_channel=None)
            else:
                add_unit = indi.init_a_densenet(mutation_position + 1, _amount=None, _k=None,
                                                _max_input_channel=None, _in_channel=_in_channel)

            # Re-adjust the channels of the units behind the new one.
            keep_out_channel = add_unit.out_channel
            for i in range(mutation_position + 1, len(indi.units)):
                follower = indi.units[i]
                if follower.type not in (1, 3):
                    continue  # pooling units carry no channel information
                self.log.info('Due to the above mutation, unit at %d changes its input channel from %d to %d' % (
                    i, follower.in_channel, keep_out_channel))
                follower.in_channel = keep_out_channel
                if follower.type == 1:
                    # A RESNET unit keeps its output channel; stop here.
                    break
                estimated_out_channel = follower.k * follower.amount + follower.in_channel
                self.log.info(
                    'Due to the above mutation, unit at %d changes its output channel from %d to %d' % (
                        i, follower.out_channel, estimated_out_channel))
                follower.out_channel = estimated_out_channel
                # Always carry the recomputed output on to the next unit.
                keep_out_channel = estimated_out_channel

        # Splice the new unit in and shift the numbers of the units behind it.
        head = list(indi.units[:mutation_position + 1])
        tail = list(indi.units[mutation_position + 1:])
        for unit in tail:
            unit.number += 1
        indi.number_id += 1
        indi.units = head + [add_unit] + tail
        indi.reset_acc()

    def do_remove_unit_mutation(self, indi):
        """Remove one randomly chosen unit (never the first one) from ``indi``.

        When a RESNET (type 1) or DENSENET (type 3) unit is removed, the units
        behind it are re-connected: the next RESNET/DENSENET unit receives the
        removed unit's input channel.  A DENSENET unit's output channel is
        recomputed from its new input and carried on to the following unit;
        propagation stops at the first RESNET unit because its output channel
        does not change.  Removing a POOLING unit needs no channel update.

        An individual with a single unit is left untouched and a warning is
        logged.
        """
        self.log.info('Do the REMOVE mutation for indi:%s' % (indi.id))
        if len(indi.units) <= 1:
            self.log.warn('REMOVE mutation can not be performed due to it has only one unit')
            return

        # Positions 1..len-1 only: the first unit is never removed.
        mutation_position = int(np.floor(np.random.random() * (len(indi.units) - 1))) + 1
        self.log.info('Mutation position occurs at %d' % (mutation_position))

        removed = indi.units[mutation_position]
        if removed.type in (1, 3):
            keep_out_channel = removed.in_channel
            for i in range(mutation_position + 1, len(indi.units)):
                follower = indi.units[i]
                if follower.type not in (1, 3):
                    continue  # pooling units carry no channel information
                self.log.info(
                    'Due to the above mutation, unit at %d changes its input channel from %d to %d' % (
                        i, follower.in_channel, keep_out_channel))
                follower.in_channel = keep_out_channel
                if follower.type == 1:
                    break
                estimated_out_channel = follower.k * follower.amount + follower.in_channel
                self.log.info(
                    'Due to the above mutation, unit at %d changes its output channel from %d to %d' % (
                        i, follower.out_channel, estimated_out_channel))
                follower.out_channel = estimated_out_channel
                # Always carry the recomputed output on to the next unit.
                keep_out_channel = estimated_out_channel
        else:
            self.log.info('A POOLING at %d is removed' % (mutation_position))

        # Drop the unit and shift the numbers of the units behind it.
        head = list(indi.units[:mutation_position])
        tail = list(indi.units[mutation_position + 1:])
        for unit in tail:
            unit.number -= 1
        indi.number_id -= 1
        indi.units = head + tail
        indi.reset_acc()

    def do_alter_mutation(self, indi):
        """Alter one property of a randomly chosen unit of ``indi``.

        The unit's type decides what is altered:
            - RESNET (type 1): output channel or amount
            - POOLING (type 2): pooling type
            - anything else (DENSENET): amount

        Returns:
            tuple: ``(mutation_p_type, mutation_p_count)`` exactly as returned
            by the type-specific method, i.e. the name of the altered property
            and a count of the changes applied.
        """
        self.log.info('Do the ALTER mutation for indi:%s' % (indi.id))

        # Pick the unit to alter.
        mutation_position = int(np.floor(np.random.random() * len(indi.units)))
        target = indi.units[mutation_position]

        # Map the unit type to its display name and its handler; every type
        # other than RESNET and POOLING is treated as DENSENET.
        handlers = {
            1: ('RESNET', self.do_alter_resnet_mutation),
            2: ('POOLING', self.do_alter_pooling_mutation),
        }
        unit_label, handler = handlers.get(
            target.type, ('DENSENET', self.do_alter_densenet_mutation))

        self.log.info(
            'Do the %s mutation for indi:%s at position %d' % (unit_label, indi.id, mutation_position))
        return handler(mutation_position, indi)

    def do_alter_resnet_mutation(self, position, indi):
        """Alter the output channel or the amount of the RESNET unit at ``position``.

        With probability 1/2 a new output channel is drawn from the configured
        channel list, otherwise a new amount is drawn from the configured
        RESNET length limits.  Nothing changes when the drawn value equals the
        current one.

        After an output-channel change the following units are re-connected:
        the next RESNET/DENSENET unit receives the new channel as its input; a
        DENSENET unit's output is recomputed and carried on to the unit after
        it; propagation stops at the first RESNET unit.

        Returns:
            tuple: ``(mutation_p_type, mutation_p_count)`` where the type is
            ``'RESNET_OUT_CHANNEL'`` or ``'RESNET_AMOUNT'`` and the count is 1
            if the unit changed, else 0.
        """
        unit = indi.units[position]
        mutation_p_count = 0

        if random.random() < 0.5:
            mutation_p_type = 'RESNET_OUT_CHANNEL'
            channel_list = StatusUpdateTool().get_output_channel()
            new_out_channel = channel_list[int(np.floor(np.random.random() * len(channel_list)))]
            if unit.out_channel != new_out_channel:
                self.log.info('Unit at %d changes its output channel from %d to %d' % (
                    position, unit.out_channel, new_out_channel))
                unit.out_channel = new_out_channel

                # Re-connect the units behind the altered one.
                keep_out_channel = new_out_channel
                for i in range(position + 1, len(indi.units)):
                    follower = indi.units[i]
                    if follower.type not in (1, 3):
                        continue  # pooling units carry no channel information
                    self.log.info('Due to above, the unit at %d should change its input channel from %d to %d' % (
                        i, follower.in_channel, keep_out_channel))
                    follower.in_channel = keep_out_channel
                    if follower.type == 1:
                        # A RESNET unit keeps its output channel; stop here.
                        break
                    estimated_out_channel = follower.k * follower.amount + follower.in_channel
                    self.log.info(
                        'Due to the above mutation, unit at %d changes its output channel from %d to %d' % (
                            i, follower.out_channel, estimated_out_channel))
                    follower.out_channel = estimated_out_channel
                    # Always carry the recomputed output on to the next unit.
                    keep_out_channel = estimated_out_channel

                mutation_p_count = 1
                indi.reset_acc()
        else:
            mutation_p_type = 'RESNET_AMOUNT'
            min_resnet_unit, max_resnet_unit = StatusUpdateTool.get_resnet_unit_length_limit()
            # NOTE(review): numpy's randint excludes the upper bound, so
            # max_resnet_unit itself is never drawn here, whereas the DENSENET
            # alteration draws from an inclusive range -- confirm which is
            # intended.
            amount = np.random.randint(min_resnet_unit, max_resnet_unit)
            if amount != unit.amount:
                self.log.info(
                    'Unit at %d changes its amount from %d to %d' % (position, unit.amount, amount))
                unit.amount = amount
                mutation_p_count = 1
                indi.reset_acc()
        return mutation_p_type, mutation_p_count

    def do_alter_densenet_mutation(self, position, indi):
        """Draw a new amount for the DENSENET unit at ``position``.

        The allowed amount range depends on the unit's growth rate ``k``
        (12, 20 or 40).  When the drawn amount differs from the current one,
        the unit's output channel is shifted by ``k`` per added/removed layer
        and the following units are re-connected: the next RESNET/DENSENET
        unit receives the new channel as its input; a DENSENET unit's output is
        recomputed and carried on to the unit after it; propagation stops at
        the first RESNET unit.

        Returns:
            tuple: ``('DENSENET_AMOUNT', count)`` with ``count`` 1 if the
            amount changed, else 0.

        Raises:
            ValueError: if the unit's ``k`` is not 12, 20 or 40.
        """
        mutation_p_type = 'DENSENET_AMOUNT'
        mutation_p_count = 0
        unit = indi.units[position]

        k = unit.k
        if k not in (12, 20, 40):
            raise ValueError('Unsupported k for a DENSENET unit: %s (expected 12, 20 or 40)' % (k,))
        if k == 12:
            _, amount_lower_limit, amount_upper_limit = StatusUpdateTool.get_densenet_k12()
        elif k == 20:
            _, amount_lower_limit, amount_upper_limit = StatusUpdateTool.get_densenet_k20()
        else:
            _, amount_lower_limit, amount_upper_limit = StatusUpdateTool.get_densenet_k40()

        # Inclusive draw: randint excludes its upper bound, hence the +1.
        amount = np.random.randint(amount_lower_limit, amount_upper_limit + 1)
        if amount != unit.amount:
            self.log.info(
                'Unit at %d changes its amount from %d to %d' % (position, unit.amount, amount))
            # Each layer contributes k channels to the output.
            new_out_channel = unit.out_channel + (amount - unit.amount) * k
            unit.amount = amount

            self.log.info('Due to the above mutation, unit at %d changes its output channel from %d to %d' % (
                position, unit.out_channel, new_out_channel))
            unit.out_channel = new_out_channel

            # Re-connect the units behind the altered one.
            keep_out_channel = new_out_channel
            for i in range(position + 1, len(indi.units)):
                follower = indi.units[i]
                if follower.type not in (1, 3):
                    continue  # pooling units carry no channel information
                self.log.info('Due to the above mutation, unit at %d changes its input channel from %d to %d' % (
                    i, follower.in_channel, keep_out_channel))
                follower.in_channel = keep_out_channel
                if follower.type == 1:
                    # A RESNET unit keeps its output channel; stop here.
                    break
                estimated_out_channel = follower.k * follower.amount + follower.in_channel
                self.log.info(
                    'Due to the above mutation, unit at %d changes its output channel from %d to %d' % (
                        i, follower.out_channel, estimated_out_channel))
                follower.out_channel = estimated_out_channel
                # Always carry the recomputed output on to the next unit.
                keep_out_channel = estimated_out_channel

            mutation_p_count = 1
            indi.reset_acc()
        return mutation_p_type, mutation_p_count

    def do_alter_pooling_mutation(self, position, indi):
        """Flip the pooling type of the POOLING unit at ``position``.

        ``max_or_avg`` above 0.5 denotes average pooling and is set to 0.2
        (max pooling); any other value is set to 0.75 (average pooling).  The
        individual's accuracy is reset.

        Returns:
            tuple: always ``('POOLING_TYPE', 1)``.
        """
        pool_unit = indi.units[position]
        was_avg = pool_unit.max_or_avg > 0.5

        pool_unit.max_or_avg = 0.2 if was_avg else 0.75
        message = ('Pool type from avg_pool (>0.5) to max_pool (<0.5)' if was_avg
                   else 'Pool type from max_pool (<0.5) to avg_pool (>0.5)')
        self.log.info(message)

        indi.reset_acc()
        return 'POOLING_TYPE', 1

    def select_mutation_type(self, _a):
        """Draw one index from ``_a`` with probability proportional to its weight.

        Roulette-wheel selection: the weights are visited from largest to
        smallest and the first index whose cumulative weight exceeds a uniform
        draw in ``[0, sum(_a))`` is returned.  The weights do not have to sum
        to one.

        Args:
            _a: sequence of non-negative weights, one per mutation type.

        Returns:
            The selected index into ``_a`` (a NumPy integer).

        Raises:
            ValueError: if ``_a`` is empty or its weights do not sum to a
                positive value.
        """
        weights = np.asarray(_a, dtype=float)
        if weights.size == 0:
            raise ValueError('At least one mutation weight is required')
        total = float(np.sum(weights))
        if total <= 0:
            raise ValueError('Mutation weights must sum to a positive value')

        # Largest weight first.
        order = np.argsort(weights)[::-1]
        sorted_weights = weights[order]

        threshold = np.random.rand() * total
        running = 0.0
        for rank, weight in enumerate(sorted_weights):
            running += weight
            if running > threshold:
                return order[rank]

        # Rounding can leave the running sum marginally below the threshold;
        # the draw then belongs to the last index with a positive weight.
        positive_ranks = np.nonzero(sorted_weights > 0)[0]
        return order[positive_ranks[-1]]


if __name__ == '__main__':
    # Minimal smoke run.  Mutation takes four arguments (individuals, prob_,
    # _log, mutation_pro) and iterates over the individuals, so the previous
    # call with three ``None`` arguments could not run.  A mutation
    # probability of 0.0 keeps the placeholder individual untouched apart from
    # being marked as "not mutated".
    import logging
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    placeholder = SimpleNamespace(id='indi0000', units=[])
    m = Mutation([placeholder], 0.0, logging.getLogger(__name__), [0.4, 0.3, 0.3])
    m.do_mutation()
